"""
Measuring Massive Multitask Language Understanding
Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2009.03300
"""

import random
import re

import requests

import blobfile as bf
import pandas

from . import common
from .common import (
  HTML_JINJA,
  MULTILINGUAL_ANSWER_PATTERN_TEMPLATE,
  MULTILINGUAL_ANSWER_REGEXES,
  format_multichoice_question,
  normalize_extracted_answer,
  normalize_response,
)
from .eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult

# Mapping from every MMLU subject to its top-level reporting category
# ("stem" / "humanities" / "social_sciences" / "other"). Maintained here
# grouped by category for readability, then flattened (and sorted by
# subject) into the flat subject -> category lookup used below.
_CATEGORY_TO_SUBJECTS = {
  "stem": [
    "abstract_algebra",
    "astronomy",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "electrical_engineering",
    "elementary_mathematics",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_mathematics",
    "high_school_physics",
    "high_school_statistics",
    "machine_learning",
  ],
  "humanities": [
    "formal_logic",
    "high_school_european_history",
    "high_school_us_history",
    "high_school_world_history",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "moral_disputes",
    "moral_scenarios",
    "philosophy",
    "prehistory",
    "professional_law",
    "world_religions",
  ],
  "social_sciences": [
    "econometrics",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_microeconomics",
    "high_school_psychology",
    "human_sexuality",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
  ],
  "other": [
    "anatomy",
    "business_ethics",
    "clinical_knowledge",
    "college_medicine",
    "global_facts",
    "human_aging",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "nutrition",
    "professional_accounting",
    "professional_medicine",
    "virology",
  ],
}

# Flat subject -> category lookup, sorted so insertion order matches the
# original alphabetical-by-subject listing.
subject2category = dict(
  sorted(
    (subject, category)
    for category, subjects in _CATEGORY_TO_SUBJECTS.items()
    for subject in subjects
  )
)


class MMLUEval(Eval):
  """Evaluation harness for MMLU (Hendrycks et al., 2020) multiple-choice QA.

  Loads the question CSV from the simple-evals public blob store,
  optionally subsamples it deterministically, and scores a sampler's
  responses by exact match on the extracted answer letter.
  """

  def __init__(self, num_examples: int | None = None, language: str = "EN-US"):
    """Load the MMLU example set.

    Args:
      num_examples: If given (and non-zero), evaluate on a fixed random
        subsample of at most this many examples; otherwise use all of them.
      language: Language code selecting a translated CSV; "EN-US" selects
        the original English file.
    """
    if language != "EN-US":
      url = (
        f"https://openaipublic.blob.core.windows.net/simple-evals/mmlu_{language}.csv"
      )
    else:
      url = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
    df = pandas.read_csv(url)
    examples = [row.to_dict() for _, row in df.iterrows()]
    if num_examples:
      # Fixed seed keeps the subsample reproducible across runs. Clamp the
      # sample size so requesting more examples than exist does not raise
      # ValueError from random.sample.
      examples = random.Random(0).sample(
        examples, min(num_examples, len(examples))
      )
    self.examples = examples

  def __call__(self, sampler: SamplerBase) -> EvalResult:
    """Run the eval with *sampler* over all loaded examples.

    Returns the aggregated EvalResult, with per-category metrics keyed by
    the MMLU category of each example's subject.
    """
    # Build (and compile) the answer-extraction patterns once per call
    # instead of re-formatting the template for every example; the
    # template and regex list are loop-invariant.
    answer_patterns = [
      re.compile(MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex))
      for answer_regex in MULTILINGUAL_ANSWER_REGEXES
    ]

    def fn(row: dict):
      # Pose the formatted multiple-choice question to the sampler.
      prompt_messages = [
        sampler._pack_message(content=format_multichoice_question(row), role="user")
      ]
      response_text = normalize_response(sampler(prompt_messages))
      # Try each multilingual answer pattern in order; first match wins.
      extracted_answer = None
      for pattern in answer_patterns:
        match = pattern.search(response_text)
        if match:
          extracted_answer = normalize_extracted_answer(match.group(1))
          break
      # Exact-match scoring against the gold answer letter.
      score = 1.0 if extracted_answer == row["Answer"] else 0.0
      html = common.jinja_env.from_string(HTML_JINJA).render(
        prompt_messages=prompt_messages,
        next_message=dict(content=response_text, role="assistant"),
        score=score,
        correct_answer=row["Answer"],
        extracted_answer=extracted_answer,
      )
      convo = prompt_messages + [dict(content=response_text, role="assistant")]
      # Record the score under the example's category so results can be
      # aggregated per MMLU category; unknown subjects fall back to "other".
      category = subject2category.get(row["Subject"], "other")
      return SingleEvalResult(
        html=html, score=score, metrics={category: score}, convo=convo
      )

    results = common.map_with_progress(fn, self.examples)
    return common.aggregate_results(results)
